import argparse
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image
from torchvision.utils import save_image
from lavis.models import load_model_and_preprocess
from blip_utils import visual_attacker


def parse_args(argv=None):
    """Parse the command-line options for the BLIP-2 visual attack script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) reads ``sys.argv[1:]``, so
            existing ``parse_args()`` calls behave exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Demo")

    parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument("--n_iters", type=int, default=5001, help="specify the number of iterations for attack.")
    # eps / alpha are integers here; the script divides them by 255 before use.
    parser.add_argument('--eps', type=int, default=16, help="epsilon of the attack budget")
    parser.add_argument('--alpha', type=int, default=1, help="step_size of the attack")
    parser.add_argument("--constrained", default=False, action='store_true')
    parser.add_argument("--batch_size", type=int, default=1, help="specify the batch size for imagenet.")
    # --ours defaults to True, so the store_true flag alone can never switch
    # it off; --no_ours writes False to the same destination.
    parser.add_argument("--ours", default=True, action='store_true')
    parser.add_argument("--no_ours", dest="ours", action='store_false',
                        help="disable --ours (which is enabled by default).")
    # Previously type=int, which rejected fractional values such as
    # "--th 0.2" even though the default itself is 0.1.
    parser.add_argument('--th', type=float, default=0.1, help="tau value")

    # NOTE(review): 'defalut' looks like a typo for 'default'; it is kept
    # unchanged so existing output locations do not move.
    parser.add_argument("--save_dir", type=str, default='defalut',
                        help="save directory")

    return parser.parse_args(argv)

# Read the CLI first so the requested GPU can be exposed before any CUDA work
# below. NOTE(review): this replaces the hard-coded "2" assigned next to the
# imports, and it only takes effect if nothing imported above has already
# initialised CUDA -- confirm that importing lavis does not.
args = parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"



    
# ========================================
#             Model Initialization
# ========================================



print('>>> Initializing Models')

# Always produce a torch.device; the CPU fallback used to be the bare string
# "cpu" while the CUDA branch produced a torch.device object.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Remember to set `llm_model` in
# ./lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml to the path where
# the Vicuna weights are stored.
# Model source: ./lavis/models/blip2_models/blip2_vicuna_instruct.py
model, vis_processor, _ = load_model_and_preprocess(
        name='blip2_vicuna_instruct',
        model_type='vicuna13b',
        is_eval=True,
        device=device,
    )
model.eval()

print('[Initialization Finished]\n')


# Create the output directory (and any missing parents) once. exist_ok=True
# replaces the racy exists()/mkdir() pair that was also repeated verbatim
# further down.
os.makedirs(args.save_dir, exist_ok=True)

import csv

# Target strings for the attacker: the first column of every non-empty row of
# the corpus CSV. csv.reader yields [] for blank lines, which previously made
# the row[0] lookup raise IndexError, so those rows are skipped.
with open("harmful_corpus/derogatory_corpus.csv", "r",
          newline="", encoding="utf-8") as corpus_file:
    targets = [row[0] for row in csv.reader(corpus_file, delimiter=",") if row]

print(args.save_dir)


# Attacker over the loaded model and the target strings read from the corpus.
my_attacker = visual_attacker.Attacker(
    args,
    model,
    targets,
    device=model.device,
    is_rtp=False,
)

# Starting image: load as RGB, apply the "eval" visual processor, prepend a
# leading dimension (a batch of one) and move it to `device`.
template_img = 'adversarial_images/n02510455_405.JPEG'
img = vis_processor["eval"](Image.open(template_img).convert('RGB'))
img = img.unsqueeze(0).to(device)

# eps / alpha come from the CLI as integers and are divided by 255 here.
adv_img_prompt = my_attacker.targeted_attack_B2H(
    img=img,
    batch_size=1,
    num_iter=args.n_iters,
    alpha=args.alpha / 255,
    epsilon=args.eps / 255,
)

print('[Done]')